import collections
import os
import pickle

import numpy as np
import pandas as pd
import torch
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.metrics import balanced_accuracy_score, f1_score
from sklearn.model_selection import StratifiedKFold, train_test_split

# Make CUDA device index 3 the current device for this process.
torch.cuda.set_device(3)

# Raw samples (GB18030-encoded CSV) and the list of model feature columns.
df_data = pd.read_csv('./feature.GB18030.v10.csv', encoding='GB18030')
df_features = pd.read_csv('./feature_names.csv')
feature_names = df_features['feature'].tolist()

# Impute every missing value with the most frequent value of its column.
# NOTE(review): the modes are computed over the full dataset (every CV fold),
# and 'plain_text' is imputed the same way as the numeric features.
column_modes = df_data.mode().iloc[0].to_dict()
df_data = df_data.fillna(column_modes)

# After mode imputation this can only drop rows when a column had no mode
# at all (e.g. it was entirely NaN).
df_data = df_data.dropna(subset=feature_names + ['plain_text'])

# Minimum number of characters a post's text must have to be kept.
min_txt_len = 1

# Optional filter (disabled): drop bare reposts whose entire text is
# '转发微博' ("Repost Weibo").
# df_data = df_data[df_data['plain_text'] != '转发微博']

# Keep rows whose text has at least `min_txt_len` characters. Cast to str
# first: a CSV text column can come back with mixed types (numeric-looking
# posts parsed as numbers), and a bare len() on those raises TypeError.
has_min_text = df_data['plain_text'].astype(str).str.len() >= min_txt_len
df_data = df_data[has_min_text]

def _encode_categoricals(column):
    """Return integer codes for an object/bool column; pass others through unchanged."""
    if column.dtype == 'object' or column.dtype == 'bool':
        codes, _uniques = pd.factorize(column)
        return codes
    return column


# Model input matrix: the selected feature columns, with object/bool columns
# replaced by their pd.factorize integer codes.
all_feature_values = df_data[feature_names].apply(_encode_categoricals).values

# NOTE(review): these thresholds are not used anywhere in this script; they
# look like leftovers from another labelling variant. Confirm before removing.
split_x_high = 9.0
split_y_high = 6.0

split_x_low = 2.0
split_y_low = 2.0

# Binary target: 0 when both 'comment_like_num' and 'child_comment_num' are
# <= 0, otherwise 1. This is vectorized instead of a per-row iterrows() loop.
# It is written as the negation of the "both <= 0" test so NaN counts behave
# as before: NaN is never <= 0, so such rows get label 1.
no_engagement = (df_data['comment_like_num'] <= 0) & (df_data['child_comment_num'] <= 0)
binary_labels = np.where(no_engagement.to_numpy(), 0, 1)
print(collections.Counter(binary_labels))

# Training hyper-parameters.
batch_size = 64
max_epochs = 2000
patience = 1000  # epochs without eval_set improvement before stopping early
# Fraction of each training fold held out as the early-stopping eval set.
val_fraction = 0.1

# One dict per outer fold: true/predicted test labels, the model's
# feature_importances_, and the (explain_mat, masks) pair from clf.explain.
result_dic = []

kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
for train_ids, test_ids in kf.split(all_feature_values, y=binary_labels):
    # Hold out a stratified validation split from the training fold. The
    # eval_set drives early stopping (and, per the pytorch_tabnet docs,
    # best-epoch weight selection). Using the test fold there would leak
    # test labels into model selection and inflate the reported scores.
    fit_ids, val_ids = train_test_split(
        train_ids,
        test_size=val_fraction,
        stratify=binary_labels[train_ids],
        random_state=42,
    )
    train_data, train_labels = all_feature_values[fit_ids], binary_labels[fit_ids]
    val_data, val_labels = all_feature_values[val_ids], binary_labels[val_ids]
    test_data, test_labels = all_feature_values[test_ids], binary_labels[test_ids]
    print(collections.Counter(test_labels))

    clf = TabNetClassifier()
    clf.fit(
        X_train=train_data, y_train=train_labels,
        eval_set=[(val_data, val_labels)],
        eval_metric=['balanced_accuracy'],
        max_epochs=max_epochs, patience=patience,
        batch_size=batch_size, virtual_batch_size=batch_size,
        num_workers=4,
        weights=1,  # per pytorch_tabnet docs: class-balanced sampling
    )
    # The test fold is used only for the final, untouched evaluation.
    y_preds = clf.predict(test_data)
    explain_mat, masks = clf.explain(test_data)
    result_dic.append({'y_true': test_labels, 'y_preds': y_preds, 'feature_imp': clf.feature_importances_, 'explain': (explain_mat, masks)})

# Persist all fold results. Create the output directory first so a long
# cross-validation run cannot fail at the very last step.
os.makedirs('./results', exist_ok=True)
with open(f'./results/tabnet_t01_binary_batch{batch_size}_epoch_{max_epochs}.pkl', 'wb') as f:
    pickle.dump(result_dic, f)
